import time
import sys
import argparse
import random
import torch
import gc
import pickle
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
from utils.metric import get_ner_fmeasure
from model.LGN import Graph,Graph_BiLSTM,Graph_Trans


# Helper for parsing boolean command-line arguments.
def str2bool(v):
    """Coerce an argparse flag value to bool.

    Accepts real bools unchanged; otherwise matches common textual
    spellings case-insensitively. Raises ArgumentTypeError on anything else.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Learning-rate schedule
def lr_decay(optimizer, epoch, decay_rate, init_lr):
    """Exponentially decay the learning rate and apply it to all param groups.

    lr = init_lr * (1 - decay_rate) ** epoch.  Groups named 'aggr'
    (the aggregation module, see the optimizer setup in train()) are
    trained at twice the base rate.

    Returns the (mutated) optimizer for chaining.
    """
    lr = init_lr * ((1 - decay_rate) ** epoch)
    # Fixed grammar of the progress message ("is setted as" -> "is set to").
    print(" Learning rate is set to:", lr)
    for param_group in optimizer.param_groups:
        # .get() avoids a KeyError for groups created without a 'name' entry.
        if param_group.get('name') == 'aggr':
            param_group['lr'] = lr * 2.
        else:
            param_group['lr'] = lr
    return optimizer

# Data initialization
def data_initialization(data, word_file, train_file, dev_file, test_file):
    """Build the word lexicon and the char/word alphabets from the corpora.

    Splits whose path is falsy (e.g. None or "") are skipped, so any
    subset of train/dev/test may be supplied.  Returns the same `data`
    object for chaining.
    """
    data.build_word_file(word_file)

    # Process splits in train -> dev -> test order, same as before.
    for corpus_path in (train_file, dev_file, test_file):
        if corpus_path:
            data.build_alphabet(corpus_path)
            data.build_word_alphabet(corpus_path)
    return data


def predict_check(pred_variable, gold_variable, mask_variable):
    """Count token-level agreement between predictions and gold labels.

    Only positions where the mask is nonzero are counted.  Returns
    (right_token, total_token) as numpy scalars.
    """
    pred_np = pred_variable.cpu().data.numpy()
    gold_np = gold_variable.cpu().data.numpy()
    mask_np = mask_variable.cpu().data.numpy()
    # Masked element-wise match count.
    right_token = np.sum((pred_np == gold_np) * mask_np)
    total_token = mask_np.sum()
    return right_token, total_token


def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet):
    """Map padded tag-id tensors back to per-sentence label strings.

    Padding positions (mask == 0) are dropped.  Returns two parallel
    lists of label-string lists: (pred_label, gold_label).
    """
    mask_np = mask_variable.cpu().data.numpy()
    pred_np = pred_variable.cpu().data.numpy()
    gold_np = gold_variable.cpu().data.numpy()
    n_sents = gold_variable.size(0)
    n_cols = gold_variable.size(1)

    pred_label, gold_label = [], []
    for row in range(n_sents):
        # Columns that carry real tokens for this sentence.
        keep = [col for col in range(n_cols) if mask_np[row][col] != 0]
        sent_pred = [label_alphabet.get_instance(pred_np[row][col]) for col in keep]
        sent_gold = [label_alphabet.get_instance(gold_np[row][col]) for col in keep]
        assert (len(sent_pred) == len(sent_gold))
        pred_label.append(sent_pred)
        gold_label.append(sent_gold)

    return pred_label, gold_label

# Print the run configuration
def print_args(args):
    """Print a human-readable summary of the parsed hyper-parameters."""
    # (template, value) pairs in the original print order; the templates
    # are reproduced verbatim so the output is unchanged.
    rows = [
        ("     Batch size: %s", args.batch_size),
        ("     If use GPU: %s", args.use_gpu),
        ("     If use CRF: %s", args.use_crf),
        ("     Epoch  number: %s", args.num_epoch),
        ("     Learning rate: %s", args.lr),
        ("     L2 normalization rate: %s", args.weight_decay),
        ("     If use edge embedding: %s", args.use_edge),
        ("     If  use  global  node: %s", args.use_global),
        ("     Bidirectional digraph: %s", args.bidirectional),
        ("     Update   step  number: %s", args.iters),
        ("     Attention  dropout   rate: %s", args.tf_drop_rate),
        ("     Embedding  dropout   rate: %s", args.emb_drop_rate),
        ("     Hidden  state   dimension: %s", args.hidden_dim),
        ("     Learning rate decay ratio: %s", args.lr_decay),
        ("     Aggregation module dropout rate: %s", args.cell_drop_rate),
        ("     Head    number   of   attention: %s", args.num_head),
        ("     Head  dimension   of  attention: %s", args.head_dim),
    ]
    print("CONFIG SUMMARY:")
    for template, value in rows:
        print(template % (value))
    print("CONFIG SUMMARY END.")
    sys.stdout.flush()

# Evaluate model predictions on one data split
def evaluate(data, args, model, name, encoding_type):
    """Decode one dataset split and score it with NER F-measure.

    `name` selects the split ("train"/"dev"/"test"/"raw"); any other
    value prints an error and exits.  `encoding_type` is the tagging
    scheme (e.g. BMES or BIO) used by the scorer.

    Returns (speed, acc, p, r, f, pred_results).
    """
    split_attrs = {
        "train": "train_Ids",
        "dev": "dev_Ids",
        "test": "test_Ids",
        "raw": "raw_Ids",
    }
    if name not in split_attrs:
        print("Error: wrong evaluate name,", name)
        exit(0)
    instances = getattr(data, split_attrs[name])

    pred_results = []
    gold_results = []

    # Inference mode: disables dropout etc.
    model.eval()
    batch_size = args.batch_size
    start_time = time.time()
    n_instances = len(instances)
    n_batches = n_instances // batch_size + 1

    for batch_idx in range(n_batches):
        lo = batch_idx * batch_size
        hi = min((batch_idx + 1) * batch_size, n_instances)
        batch = instances[lo:hi]
        if not batch:
            continue

        word_list, batch_char, batch_label, mask = batchify_with_label(batch, args.use_gpu)
        _, tag_seq = model(word_list, batch_char, mask)

        pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)
        pred_results.extend(pred_label)
        gold_results.extend(gold_label)

    decode_time = time.time() - start_time
    speed = n_instances / decode_time
    acc, p, r, f = get_ner_fmeasure(gold_results, pred_results, encoding_type)
    return speed, acc, p, r, f, pred_results

# Pad and batch the instances into tensors
def batchify_with_label(input_batch_list, gpu):
    """Pad a list of instances into right-zero-padded batch tensors.

    Each instance is [char_ids, word_list, label_ids].  Returns
    (words, char_seq_tensor, label_seq_tensor, mask) where the two
    tensors are (batch, max_len) LongTensors and `mask` is a bool
    tensor that is True on real tokens.  All tensors are moved to the
    GPU when `gpu` is truthy.
    """
    batch_size = len(input_batch_list)
    chars = [sent[0] for sent in input_batch_list]
    words = [sent[1] for sent in input_batch_list]
    labels = [sent[2] for sent in input_batch_list]

    sent_lengths = torch.LongTensor(list(map(len, chars)))
    max_sent_len = int(sent_lengths.max())
    # torch.autograd.Variable is a deprecated no-op wrapper; allocate
    # plain tensors with explicit dtypes instead (mask directly as bool,
    # dropping the old byte -> bool round-trip).
    char_seq_tensor = torch.zeros((batch_size, max_sent_len), dtype=torch.long)
    label_seq_tensor = torch.zeros((batch_size, max_sent_len), dtype=torch.long)
    mask = torch.zeros((batch_size, max_sent_len), dtype=torch.bool)

    for idx, (seq, label, seq_len) in enumerate(zip(chars, labels, sent_lengths)):
        char_seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
        label_seq_tensor[idx, :seq_len] = torch.LongTensor(label)
        # Scalar fill instead of building a temporary [1]*n tensor.
        mask[idx, :seq_len] = True

    if gpu:
        char_seq_tensor = char_seq_tensor.cuda()
        label_seq_tensor = label_seq_tensor.cuda()
        mask = mask.cuda()

    return words, char_seq_tensor, label_seq_tensor, mask

# Training
def train(model_type,data, args, saved_model_path, encoding_type):
    """Train a graph-based NER model, evaluating and checkpointing each epoch.

    Args:
        model_type: "Graph", "GraphBiLSTM" or "GraphTrans" — selects the
            model class from model.LGN.
        data: dataset holder exposing train_Ids and label_alphabet.
        args: parsed hyper-parameters (lr, batch_size, num_epoch, ...).
        saved_model_path: path prefix for per-epoch checkpoints, the
            "_best" checkpoint, the "_result.txt" log and the pickled
            "_best_HP.config".
        encoding_type: tagging scheme forwarded to evaluate() (e.g. BMES/BIO).
    """
    print("Training model...")
    # Select the model variant.
    if model_type =="Graph":
        model = Graph(data, args)
    elif model_type =="GraphBiLSTM":
        model = Graph_BiLSTM(data, args)
    elif model_type =="GraphTrans":
        model = Graph_Trans(data, args)
    # NOTE(review): an unrecognized model_type leaves `model` unbound and
    # raises NameError below — consider failing fast with a clear message.
    # Move the model to GPU when requested.
    if args.use_gpu:
        model = model.cuda()
    print('模型参数数量:', sum(param.numel() for param in model.parameters()))
    print("Finished built model.")
    # Best scores observed so far; model selection is done on the dev set,
    # with the test scores of the same epoch recorded for reporting.
    best_dev_epoch = 0
    best_dev_f = -1
    best_dev_p = -1
    best_dev_r = -1
    best_test_f = -1
    best_test_p = -1
    best_test_r = -1
    # ---- Optimizer ----
    # Split parameters into two groups: anything inside a ModuleList is
    # treated as the aggregation module ("aggr"), which lr_decay() trains
    # at twice the base learning rate; everything else goes into "other".
    aggr_module_params = []
    other_module_params = []
    for m_name in model._modules:
        m = model._modules[m_name]
        if isinstance(m, torch.nn.ModuleList):
            for p in m.parameters():
                if p.requires_grad:
                    aggr_module_params.append(p)
        else:
            for p in m.parameters():
                if p.requires_grad:
                    other_module_params.append(p)
    optimizer = optim.Adam([
        {"params": (aggr_module_params), "name": "aggr"},
        {"params": (other_module_params), "name": "other"}
    ],
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    # ---- Training loop ----
    # One iteration per epoch.
    for idx in range(args.num_epoch):
        epoch_start = time.time()
        temp_start = epoch_start
        print(("Epoch: %s/%s" % (idx, args.num_epoch)))
        optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)
        sample_loss = 0   # loss accumulated since the last progress print
        batch_loss = 0    # loss accumulated since the last optimizer step
        total_loss = 0    # loss accumulated over the whole epoch
        right_token = 0
        whole_token = 0
        random.shuffle(data.train_Ids)
        # Switch the model to training mode (enables dropout etc.).
        model.train()
        model.zero_grad()
        batch_size = args.batch_size
        train_num = len(data.train_Ids)
        total_batch = train_num // batch_size + 1
        # One iteration per batch.
        for batch_id in range(total_batch):
            # Get one batch-sized instance
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size
            if end > train_num:
                end = train_num
            instance = data.train_Ids[start:end]
            if not instance:
                continue

            word_list, batch_char, batch_label, mask = batchify_with_label(instance, args.use_gpu)
            # Forward pass: returns the loss and the predicted (CRF-decoded)
            # tag sequence.
            loss, tag_seq = model(word_list, batch_char, mask, batch_label)
            right, whole = predict_check(tag_seq, batch_label, mask)
            right_token += right
            whole_token += whole
            sample_loss += loss.data
            total_loss += loss.data
            batch_loss += loss

            # Progress report roughly every 500 training instances.
            if end % 500 == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                print(("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
                       (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)))
                sys.stdout.flush()
                sample_loss = 0
            # Backpropagate once per full batch.
            # NOTE(review): when train_num is not a multiple of batch_size,
            # the final partial batch fails this check, so its accumulated
            # loss is never backpropagated — confirm this is intentional.
            if end % args.batch_size == 0:
                batch_loss.backward()
                optimizer.step()
                model.zero_grad()
                batch_loss = 0

        temp_time = time.time()
        temp_cost = temp_time - temp_start
        print(("     Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f" %
               (end, temp_cost, sample_loss, right_token, whole_token, (right_token + 0.) / whole_token)))
        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        print(("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s,  total loss: %s" %
               (idx, epoch_cost, train_num / epoch_cost, total_loss)))

        # Evaluate on the dev split.
        speed, acc, dev_p, dev_r, dev_f, _ = evaluate(data, args, model, "dev", encoding_type)
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish

        print(("Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
               (dev_cost, speed, acc, dev_p, dev_r, dev_f)))

        # Evaluate on the test split.
        speed, acc, test_p, test_r, test_f, _ = evaluate(data, args, model, "test", encoding_type)
        test_finish = time.time()
        test_cost = test_finish - dev_finish

        print(("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
               (test_cost, speed, acc, test_p, test_r, test_f)))

        # New best dev F1: save the "_best" checkpoint and record the
        # matching test scores.
        if dev_f > best_dev_f:
            print("Exceed previous best f score: %.4f" % best_dev_f)
            torch.save(model.state_dict(), saved_model_path + "_best")
            best_dev_p = dev_p
            best_dev_r = dev_r
            best_dev_f = dev_f
            best_dev_epoch = idx + 1
            best_test_p = test_p
            best_test_r = test_r
            best_test_f = test_f

        # Also checkpoint every epoch and append its scores to the log.
        model_idx_path = saved_model_path + "_" + str(idx)
        torch.save(model.state_dict(), model_idx_path)
        with open(saved_model_path + "_result.txt", "a") as file:
            file.write(model_idx_path + '\n')
            file.write("Dev score: %.4f, r: %.4f, f: %.4f\n" % (dev_p, dev_r, dev_f))
            file.write("Test score: %.4f, r: %.4f, f: %.4f\n\n" % (test_p, test_r, test_f))
            file.close()  # redundant: the with-statement already closes the file

        print("Best dev epoch: %d" % best_dev_epoch)
        print("Best dev score: p: %.4f, r: %.4f, f: %.4f" % (best_dev_p, best_dev_r, best_dev_f))
        print("Best test score: p: %.4f, r: %.4f, f: %.4f" % (best_test_p, best_test_r, best_test_f))

        gc.collect()

    # Final summary of the best epoch's scores.
    with open(saved_model_path + "_result.txt", "a") as file:
        file.write("Best epoch: %d" % best_dev_epoch + '\n')
        file.write("Best Dev score: %.4f, r: %.4f, f: %.4f\n" % (best_dev_p, best_dev_r, best_dev_f))
        file.write("Test score: %.4f, r: %.4f, f: %.4f\n\n" % (best_test_p, best_test_r, best_test_f))
        file.close()  # redundant: the with-statement already closes the file

    # Persist the hyper-parameters alongside the best checkpoint.
    with open(saved_model_path + "_best_HP.config", "wb") as file:
        pickle.dump(args, file)

# Load a trained model and run decoding
def load_model_decode(model_dir, data, args, name, encoding_type):
    """Restore the best checkpoint and decode the named data split.

    `model_dir` is the checkpoint prefix; the "_best" file saved by
    train() is loaded.  Returns the predicted label sequences.
    """
    best_path = model_dir + "_best"
    print("Load Model from file: ", best_path)
    model = Graph(data, args)
    model.load_state_dict(torch.load(best_path))

    # load model need consider if the model trained in GPU and load in CPU, or vice versa
    if args.use_gpu:
        model = model.cuda()

    print(("Decode %s data ..." % name))
    t0 = time.time()
    speed, acc, p, r, f, pred_results = evaluate(data, args, model, name, encoding_type)
    elapsed = time.time() - t0
    print(("%s: time:%.2fs, speed:%.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" %
           (name, elapsed, speed, acc, p, r, f)))
    return pred_results